{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Get started with Environments, TED and transforms\n",
    "\n",
    "**Author**: [Vincent Moens](https://github.com/vmoens)\n",
    "\n",
    "\n",
    "<div class=\"alert alert-info\"><h4>Note</h4><p>To run this tutorial in a notebook, add an installation cell\n",
    "  at the beginning containing:\n",
    "\n",
    "```\n",
    "!pip install tensordict\n",
    "!pip install torchrl</p></div>\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Welcome to the getting started tutorials!\n",
    "\n",
    "Below is the list of the topics we will be covering.\n",
    "\n",
    "- `Environments, TED and transforms <gs_env_ted>`;\n",
    "- `TorchRL's modules <gs_modules>`;\n",
    "- `Losses and optimization <gs_optim>`;\n",
    "- `Data collection and storage <gs_storage>`;\n",
    "- `TorchRL's logging API <gs_logging>`.\n",
    "\n",
    "If you are in a hurry, you can jump straight away to the last tutorial,\n",
    "`Your own first training loop <gs_first_training>`, from where you can\n",
    "backtrack every other \"Getting Started\" tutorial if things are not clear or\n",
    "if you want to learn more about a specific topic!\n",
    "\n",
    "## Environments in RL\n",
    "\n",
    "The standard RL (Reinforcement Learning) training loop involves a model,\n",
    "also known as a policy, which is trained to accomplish a task within a\n",
    "specific environment. Often, this environment is a simulator that accepts\n",
    "actions as input and produces an observation along with some metadata as\n",
    "output.\n",
    "\n",
    "In this document, we will explore the environment API of TorchRL: we will\n",
    "learn how to create an environment, interact with it, and understand the\n",
    "data format it uses.\n",
    "\n",
    "## Creating an environment\n",
    "\n",
    "In essence, TorchRL does not directly provide environments, but instead\n",
    "offers wrappers for other libraries that encapsulate the simulators. The\n",
    ":mod:`~torchrl.envs` module can be viewed as a provider for a generic\n",
    "environment API, as well as a central hub for simulation backends like\n",
    "[gym](https://arxiv.org/abs/1606.01540) (:class:`~torchrl.envs.GymEnv`),\n",
    "[Brax](https://arxiv.org/abs/2106.13281) (:class:`~torchrl.envs.BraxEnv`)\n",
    "or [DeepMind Control Suite](https://arxiv.org/abs/1801.00690)\n",
    "(:class:`~torchrl.envs.DMControlEnv`).\n",
    "\n",
    "Creating your environment is typically as straightforward as the underlying\n",
    "backend API allows. Here's an example using gym:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from torchrl.envs import GymEnv\n",
    "\n",
    "env = GymEnv(\"Pendulum-v1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running an environment\n",
    "\n",
    "Environments in TorchRL have two crucial methods:\n",
    ":meth:`~torchrl.envs.EnvBase.reset`, which initiates\n",
    "an episode, and :meth:`~torchrl.envs.EnvBase.step`, which executes an\n",
    "action selected by the actor.\n",
    "In TorchRL, environment methods read and write\n",
    ":class:`~tensordict.TensorDict` instances.\n",
    "Essentially, :class:`~tensordict.TensorDict` is a generic key-based data\n",
    "carrier for tensors.\n",
    "The benefit of using TensorDict over plain tensors is that it enables us to\n",
    "handle simple and complex data structures interchangeably. As our function\n",
    "signatures are very generic, it eliminates the challenge of accommodating\n",
    "different data formats. In simpler terms, after this brief tutorial,\n",
    "you will be capable of operating on both simple and highly complex\n",
    "environments, as their user-facing API is identical and simple!\n",
    "\n",
    "Let's put the environment into action and see what a tensordict instance\n",
    "looks like:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorDict(\n",
      "    fields={\n",
      "        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "    batch_size=torch.Size([]),\n",
      "    device=cpu,\n",
      "    is_shared=False)\n"
     ]
    }
   ],
   "source": [
    "reset = env.reset()\n",
    "print(reset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's take a random action in the action space. First, sample the action:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorDict(\n",
      "    fields={\n",
      "        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "    batch_size=torch.Size([]),\n",
      "    device=cpu,\n",
      "    is_shared=False)\n"
     ]
    }
   ],
   "source": [
    "reset_with_action = env.rand_action(reset)\n",
    "print(reset_with_action)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tensordict has the same structure as the one obtained from\n",
    ":meth:`~torchrl.envs.EnvBase` with an additional ``\"action\"`` entry.\n",
    "You can access the action easily, like you would do with a regular\n",
    "dictionary:\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.4607])\n"
     ]
    }
   ],
   "source": [
    "print(reset_with_action[\"action\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now need to pass this action tp the environment.\n",
    "We'll be passing the entire tensordict to the ``step`` method, since there\n",
    "might be more than one tensor to be read in more advanced cases like\n",
    "Multi-Agent RL or stateless environments:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorDict(\n",
      "    fields={\n",
      "        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        next: TensorDict(\n",
      "            fields={\n",
      "                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "                observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "            batch_size=torch.Size([]),\n",
      "            device=cpu,\n",
      "            is_shared=False),\n",
      "        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "    batch_size=torch.Size([]),\n",
      "    device=cpu,\n",
      "    is_shared=False)\n"
     ]
    }
   ],
   "source": [
    "stepped_data = env.step(reset_with_action)\n",
    "print(stepped_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, this new tensordict is identical to the previous one except for the\n",
    "fact that it has a ``\"next\"`` entry (itself a tensordict!) containing the\n",
    "observation, reward and done state resulting from\n",
    "our action.\n",
    "\n",
    "We call this format TED, for\n",
    "`TorchRL Episode Data format <TED-format>`. It is\n",
    "the ubiquitous way of representing data in the library, both dynamically like\n",
    "here, or statically with offline datasets.\n",
    "\n",
    "The last bit of information you need to run a rollout in the environment is\n",
    "how to bring that ``\"next\"`` entry at the root to perform the next step.\n",
    "TorchRL provides a dedicated :func:`~torchrl.envs.utils.step_mdp` function\n",
    "that does just that: it filters out the information you won't need and\n",
    "delivers a data structure corresponding to your observation after a step in\n",
    "the Markov Decision Process, or MDP.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorDict(\n",
      "    fields={\n",
      "        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "    batch_size=torch.Size([]),\n",
      "    device=cpu,\n",
      "    is_shared=False)\n"
     ]
    }
   ],
   "source": [
    "from torchrl.envs import step_mdp\n",
    "\n",
    "data = step_mdp(stepped_data)\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment rollouts\n",
    "\n",
    "\n",
    "Writing down those three steps (computing an action, making a step,\n",
    "moving in the MDP) can be a bit tedious and repetitive. Fortunately,\n",
    "TorchRL provides a nice :meth:`~torchrl.envs.EnvBase.rollout` function that\n",
    "allows you to run them in a closed loop at will:\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorDict(\n",
      "    fields={\n",
      "        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        next: TensorDict(\n",
      "            fields={\n",
      "                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "                observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "            batch_size=torch.Size([10]),\n",
      "            device=cpu,\n",
      "            is_shared=False),\n",
      "        observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "    batch_size=torch.Size([10]),\n",
      "    device=cpu,\n",
      "    is_shared=False)\n"
     ]
    }
   ],
   "source": [
    "rollout = env.rollout(max_steps=10)\n",
    "print(rollout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This data looks pretty much like the ``stepped_data`` above with the\n",
    "exception of its batch-size, which now equates the number of steps we\n",
    "provided through the ``max_steps`` argument. The magic of tensordict\n",
    "doesn't end there: if you're interested in a single transition of this\n",
    "environment, you can index the tensordict like you would index a tensor:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorDict(\n",
      "    fields={\n",
      "        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        next: TensorDict(\n",
      "            fields={\n",
      "                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "                observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "            batch_size=torch.Size([]),\n",
      "            device=cpu,\n",
      "            is_shared=False),\n",
      "        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "    batch_size=torch.Size([]),\n",
      "    device=cpu,\n",
      "    is_shared=False)\n"
     ]
    }
   ],
   "source": [
    "transition = rollout[3]\n",
    "print(transition)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ":class:`~tensordict.TensorDict` will automatically check if the index you\n",
    "provided is a key (in which case we index along the key-dimension) or a\n",
    "spatial index like here.\n",
    "\n",
    "Executed as such (without a policy), the ``rollout`` method may seem rather\n",
    "useless: it just runs random actions. If a policy is available, it can\n",
    "be passed to the method and used to collect data.\n",
    "\n",
    "Nevertheless, it can useful to run a naive, policyless rollout at first to\n",
    "check what is to be expected from an environment at a glance.\n",
    "\n",
    "To appreciate the versatility of TorchRL's API, consider the fact that the\n",
    "rollout method is universally applicable. It functions across **all** use\n",
    "cases, whether you're working with a single environment like this one,\n",
    "multiple copies across various processes, a multi-agent environment, or even\n",
    "a stateless version of it!\n",
    "\n",
    "\n",
    "## Transforming an environment\n",
    "\n",
    "Most of the time, you'll want to modify the output of the environment to\n",
    "better suit your requirements. For example, you might want to monitor the\n",
    "number of steps executed since the last reset, resize images, or stack\n",
    "consecutive observations together.\n",
    "\n",
    "In this section, we'll examine a simple transform, the\n",
    ":class:`~torchrl.envs.transforms.StepCounter` transform.\n",
    "The complete list of transforms can be found\n",
    "`here <transforms>`.\n",
    "\n",
    "The transform is integrated with the environment through a\n",
    ":class:`~torchrl.envs.transforms.TransformedEnv`:\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorDict(\n",
      "    fields={\n",
      "        action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        next: TensorDict(\n",
      "            fields={\n",
      "                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "                observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "                reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "                step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),\n",
      "                terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "                truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "            batch_size=torch.Size([10]),\n",
      "            device=cpu,\n",
      "            is_shared=False),\n",
      "        observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),\n",
      "        step_count: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),\n",
      "        terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),\n",
      "        truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},\n",
      "    batch_size=torch.Size([10]),\n",
      "    device=cpu,\n",
      "    is_shared=False)\n"
     ]
    }
   ],
   "source": [
    "from torchrl.envs import StepCounter, TransformedEnv\n",
    "\n",
    "transformed_env = TransformedEnv(env, StepCounter(max_steps=10))\n",
    "rollout = transformed_env.rollout(max_steps=100)\n",
    "print(rollout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, our environment now has one more entry, ``\"step_count\"`` that\n",
    "tracks the number of steps since the last reset.\n",
    "Given that we passed the optional\n",
    "argument ``max_steps=10`` to the transform constructor, we also truncated the\n",
    "trajectory after 10 steps (not completing a full rollout of 100 steps like\n",
    "we asked with the ``rollout`` call). We can see that the trajectory was\n",
    "truncated by looking at the truncated entry:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[False],\n",
      "        [False],\n",
      "        [False],\n",
      "        [False],\n",
      "        [False],\n",
      "        [False],\n",
      "        [False],\n",
      "        [False],\n",
      "        [False],\n",
      "        [ True]])\n"
     ]
    }
   ],
   "source": [
    "print(rollout[\"next\", \"truncated\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is all for this short introduction to TorchRL's environment API!\n",
    "\n",
    "## Next steps\n",
    "\n",
    "To explore further what TorchRL's environments can do, go and check:\n",
    "\n",
    "- The :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset` method that packs\n",
    "  together :meth:`~torchrl.envs.EnvBase.step`,\n",
    "  :func:`~torchrl.envs.utils.step_mdp` and\n",
    "  :meth:`~torchrl.envs.EnvBase.reset`.\n",
    "- Some environments like :class:`~torchrl.envs.GymEnv` support rendering\n",
    "  through the ``from_pixels`` argument. Check the class docstrings to know\n",
    "  more!\n",
    "- The batched environments, in particular :class:`~torchrl.envs.ParallelEnv`\n",
    "  which allows you to run multiple copies of one same (or different!)\n",
    "  environments on multiple processes.\n",
    "- Design your own environment with the\n",
    "  `Pendulum tutorial <pendulum_tuto>` and learn about specs and\n",
    "  stateless environments.\n",
    "- See the more in-depth tutorial about environments\n",
    "  `in the dedicated tutorial <envs_tuto>`;\n",
    "- Check the\n",
    "  `multi-agent  environment API <MARL-environment-API>`\n",
    "  if you're interested in MARL;\n",
    "- TorchRL has many tools to interact with the Gym API such as\n",
    "  a way to register TorchRL envs in the Gym register through\n",
    "  :meth:`~torchrl.envs.EnvBase.register_gym`, an API to read\n",
    "  the info dictionaries through\n",
    "  :meth:`~torchrl.envs.EnvBase.set_info_dict_reader` or a way\n",
    "  to control the gym backend thanks to\n",
    "  :func:`~torchrl.envs.set_gym_backend`.\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mac-rl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
